import json
import os
import random
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf

from algorithms import adam
from algorithms import em
from algorithms import greedy
from algorithms import svd
from algorithms import svd_w

# Random seed value exposed on the command line as --seed.
flags.DEFINE_integer('seed', 2023, 'Random seed')
# Target rank of the low-rank approximation; run() reads it as FLAGS.rank and
# passes it to the approximation routines.
flags.DEFINE_integer('rank', 10, 'Rank')
# Rank used when approximating the Fisher matrix (per the flag's help text).
flags.DEFINE_integer('fisher_rank', 1, 'Rank for approximating Fisher matrix')
# .npy file with the Fisher matrix; run() loads it with np.load and uses it as
# the per-entry weights of the weighted approximation and of weighted_loss.
flags.DEFINE_string('fisher_file', 'fisher.npy', 'File with Fisher matrix')
# .npy file with the weight matrix that run() approximates.
flags.DEFINE_string('weight_file', 'weight.npy', 'File with weight matrix')
# Path of the JSON file that main() writes the per-algorithm histories to.
flags.DEFINE_string('output_file', 'results.json', 'File to write results')
FLAGS = flags.FLAGS
# TensorFlow's Python logger (tf.get_logger()), used here for progress messages.
logger = tf.get_logger()


def weighted_loss(y_true, y_pred, weight):
    """Return the weighted squared error ``sum(weight * (y_true - y_pred)**2)``.

    The residual is scaled entrywise by ``sqrt(weight)``, and the squared L2
    (Frobenius, for matrices) norm of the scaled residual is returned.
    ``weight`` is therefore expected to be non-negative; negative entries
    turn into NaN through ``np.sqrt``.

    Args:
      y_true: Reference array.
      y_pred: Approximation of ``y_true``; must broadcast against it.
      weight: Per-entry weights; must broadcast against the residual.

    Returns:
      The scalar weighted squared error.
    """
    residual = y_true - y_pred
    scaled_residual = np.sqrt(weight) * residual
    return np.linalg.norm(scaled_residual) ** 2


def run(algo) -> list:
    """Run one weighted low-rank approximation algorithm and return its history.

    Loads the weight matrix (``--weight_file``) and the Fisher matrix
    (``--fisher_file``). It then computes a rank-``--rank`` approximation of
    the weight matrix, using the Fisher matrix as per-entry weights, and logs
    the final Fisher-weighted reconstruction loss.

    Args:
      algo: Algorithm name, either ``'em'`` or ``'svd_w+em'``.
        ``'svd_w+em'`` seeds EM with the svd_w solution, computed with a
        Fisher approximation of rank ``--fisher_rank``.

    Returns:
      The history returned by ``em.weighted_lra(..., return_history=True)``.

    Raises:
      ValueError: If ``algo`` is not a supported algorithm name.
    """
    logger.info('Loading matrices...')
    fisher = np.load(FLAGS.fisher_file)
    weight = np.load(FLAGS.weight_file)
    logger.info('Loaded Fisher matrix of size %d x %d', *fisher.shape)
    logger.info('Loaded weight matrix of size %d x %d', *weight.shape)
    rank = FLAGS.rank

    if algo == 'em':
        left_factor, right_factor, history = em.weighted_lra(
            weight, fisher, rank, return_history=True)
    elif algo == 'svd_w+em':
        # The Fisher-approximation rank comes from --fisher_rank rather than a
        # hard-coded value, so the flag actually takes effect.
        inv_fisher_u, inv_fisher_v, weight_u, weight_v = svd_w.weighted_lra(
            weight, fisher, rank, fisher_rank=FLAGS.fisher_rank)
        low_rank_inv_fisher = inv_fisher_u @ inv_fisher_v
        low_rank = weight_u @ weight_v
        # Seed EM with the svd_w solution: the elementwise product of the two
        # low-rank reconstructions.
        left_factor, right_factor, history = em.weighted_lra(
            weight, fisher, rank,
            initial_solution=low_rank_inv_fisher * low_rank,
            return_history=True)
    else:
        raise ValueError(f'Invalid algorithm {algo}')

    # Report the final Fisher-weighted loss of the rank-`rank` reconstruction.
    loss = weighted_loss(weight, left_factor @ right_factor, fisher)
    logger.info('%s: final weighted loss %g', algo, loss)
    return history


def _json_default(obj):
    """Convert NumPy values that the ``json`` module cannot serialize natively.

    Used as the ``default`` hook of ``json.dump``. Arrays become nested lists
    and NumPy scalars become the matching Python scalars. Any other type
    raises ``TypeError``, just as ``json`` would without the hook.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(
        f'Object of type {type(obj).__name__} is not JSON serializable')


def main(argv) -> None:
    """Run every algorithm and write the histories to ``--output_file`` as JSON.

    The output is a JSON object mapping each algorithm name to the history
    returned by ``run``.
    """
    del argv  # Unused.
    results = {}
    for algo in ('em', 'svd_w+em'):
        # Re-seed before each algorithm so --seed takes effect and each run is
        # reproducible regardless of the order the algorithms run in.
        random.seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        results[algo] = run(algo)
    with open(FLAGS.output_file, 'w', encoding='utf-8') as fp:
        # NOTE(review): history contents come from em.weighted_lra; the hook
        # only matters if they include NumPy scalars/arrays.
        json.dump(results, fp, default=_json_default)


if __name__ == '__main__':
    # absl's app.run parses the command-line flags, then calls main(argv).
    app.run(main)
